{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "856edcf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload imported modules before each cell runs so edits to local code are picked up\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "80e34dd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os,sys\n",
    "sys.path.insert(0,\"..\")\n",
    "from glob import glob\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torchvision\n",
    "import tqdm\n",
    "import sklearn, sklearn.metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "78ddd796",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchxrayvision as xrv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "032df440",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "    Chest Radiograph Interpretation with Deep Learning Models: Assessment with \n",
      "    Radiologist-adjudicated Reference Standards and Population-adjusted Evaluation\n",
      "    Anna Majkowska, Sid Mittal, David F. Steiner, Joshua J. Reicher, Scott Mayer \n",
      "    McKinney, Gavin E. Duggan, Krish Eswaran, Po-Hsuan Cameron Chen, Yun Liu, \n",
      "    Sreenivasa Raju Kalidindi, Alexander Ding, Greg S. Corrado, Daniel Tse, and \n",
      "    Shravya Shetty. Radiology 2020\n",
      "    \n",
      "    https://pubs.rsna.org/doi/10.1148/radiol.2019191293\n",
      "    \n"
     ]
    }
   ],
   "source": [
    "# Print the class docstring, which holds the paper reference for this dataset\n",
    "print(xrv.datasets.NIH_Google_Dataset.__doc__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "67b3cfd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix the RNG so the classifier initialisation and the shuffled batch order are repeatable\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Image directory: set the NIH_IMGPATH environment variable to override;\n",
    "# the fallback is the location this notebook was originally run with\n",
    "NIH_IMGPATH = os.environ.get(\"NIH_IMGPATH\", \"/Users/ieee8023/Datasets/NIH/images-224\")\n",
    "\n",
    "# Use XRV transforms to crop and resize the images\n",
    "transforms = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop(),\n",
    "                                             xrv.datasets.XRayResizer(224)])\n",
    "\n",
    "# Load Google dataset and PyTorch dataloader\n",
    "dataset = xrv.datasets.NIH_Google_Dataset(imgpath=NIH_IMGPATH,\n",
    "                                          transform=transforms)\n",
    "dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)\n",
    "\n",
    "# Load pre-trained model and erase classifier\n",
    "model = xrv.models.DenseNet(weights=\"densenet121-res224-all\")\n",
    "model.op_threshs = None # prevent pre-trained model calibration\n",
    "model.requires_grad_(False) # freeze pre-trained weights so backward skips the backbone\n",
    "model.classifier = torch.nn.Linear(1024,1) # reinitialize classifier (new layer is trainable)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.classifier.parameters()) # only train classifier\n",
    "criterion = torch.nn.BCEWithLogitsLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47700846",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1e4309e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ad5cc0c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1d524194",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.7197195\n",
      "1 0.7032325\n",
      "2 0.7019479\n",
      "3 0.6871945\n",
      "4 0.7000455\n",
      "5 0.66800505\n",
      "6 0.6873612\n",
      "7 0.6366669\n",
      "8 0.5571896\n",
      "9 0.59965414\n",
      "10 0.5273169\n",
      "11 0.50774366\n",
      "12 0.5894309\n",
      "13 0.6490113\n",
      "14 0.6630047\n",
      "15 0.46392265\n",
      "16 0.43167824\n",
      "17 0.6394826\n",
      "18 0.6443106\n",
      "19 0.5077342\n",
      "20 0.57959396\n"
     ]
    }
   ],
   "source": [
    "# training loop (can run on cpu)\n",
    "lung_opacity_idx = dataset.pathologies.index(\"Lung Opacity\")  # label column used as the target\n",
    "for i, batch in enumerate(dataloader):\n",
    "    if i > 20: break  # short demo run: stop after 21 mini-batches\n",
    "    optimizer.zero_grad()  # clear gradients from the previous step; otherwise they accumulate\n",
    "    outputs = model(batch[\"img\"])\n",
    "    targets = batch[\"lab\"][:, lung_opacity_idx, None]  # None keeps a trailing axis for the 1-unit classifier output\n",
    "    loss = criterion(outputs, targets)\n",
    "    print(i, loss.detach().cpu().numpy())\n",
    "    loss.backward()\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64107240",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "id": "2c0e5e85",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grab the first dataset item; its \"img\" entry is fed to the model in the next cell\n",
    "sample = dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "id": "cc36469d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the model on one image and map its raw output to a probability\n",
    "img_batch = torch.from_numpy(sample[\"img\"]).unsqueeze(0)  # add a leading batch dimension\n",
    "out = torch.sigmoid(model(img_batch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 130,
   "id": "140e37e0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.3993]], grad_fn=<SigmoidBackward>)"
      ]
     },
     "execution_count": 130,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the sigmoid output computed in the previous cell\n",
    "out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "id": "f3eb3cf6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0 0.399319\n",
      "1.0 0.4504822\n",
      "1.0 0.46033397\n",
      "0.0 0.33570313\n",
      "0.0 0.39914772\n",
      "0.0 0.4095046\n",
      "0.0 0.42468312\n",
      "0.0 0.39061117\n",
      "1.0 0.43495628\n",
      "1.0 0.38757616\n",
      "0.0 0.45992658\n",
      "0.0 0.42239773\n",
      "1.0 0.47049743\n",
      "0.0 0.37503645\n",
      "0.0 0.39930966\n",
      "0.0 0.42051744\n",
      "0.0 0.37670723\n",
      "0.0 0.3471031\n",
      "0.0 0.3745227\n",
      "1.0 0.4336192\n"
     ]
    }
   ],
   "source": [
    "# Score the first 20 dataset items, collecting the Lung Opacity label and the\n",
    "# sigmoid of the model output for each one\n",
    "target_idx = dataset.pathologies.index(\"Lung Opacity\")\n",
    "labels, preds = [], []\n",
    "with torch.inference_mode():\n",
    "    for i in range(20):\n",
    "        sample = dataset[i]\n",
    "        label = sample[\"lab\"][target_idx]\n",
    "        image = torch.from_numpy(sample[\"img\"])[None]  # add a leading batch dimension\n",
    "        pred = torch.sigmoid(model(image)).detach().numpy()[0, 0]\n",
    "        labels.append(label)\n",
    "        preds.append(pred)\n",
    "        print(label, pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 132,
   "id": "c9b19bed",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8571428571428572"
      ]
     },
     "execution_count": 132,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# ROC AUC of the predicted probabilities against the Lung Opacity labels gathered in the loop above\n",
    "sklearn.metrics.roc_auc_score(labels, preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f99a6c8c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "051afa80",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
